{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c09ada11-d3a6-4837-b4d0-9657106a54b5",
   "metadata": {},
   "source": [
    "# RLSS2023 - DQN Tutorial: Deep Q-Network (DQN)\n",
    "\n",
    "## Part I: DQN Components: Replay Buffer, Q-Network, ...\n",
    "\n",
    "Website: https://rlsummerschool.com/\n",
    "\n",
    "Github repository: https://github.com/araffin/rlss23-dqn/\n",
    "\n",
    "Gymnasium documentation: https://gymnasium.farama.org/\n",
    "\n",
    "## Introduction\n",
    "\n",
    "In this notebook, you will implement the main components [Deep Q-Network (DQN)](https://stable-baselines3.readthedocs.io/en/master/modules/dqn.html) algorithm: a replay buffer, epsilon-greedy exploration strategy and a q-network.\n",
    "\n",
    "In Part II, you will implement the training loop and the DQN gradient update."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fe1a9aa-5735-4614-9a76-031656397899",
   "metadata": {},
   "outputs": [],
   "source": [
    "# for autoformatting\n",
    "# !pip install jupyter-black\n",
    "# %load_ext jupyter_black"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8188798b-daf5-43a7-91ec-a7a922bc2034",
   "metadata": {},
   "source": [
    "## Install Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b55494c-fff2-4459-87e1-e7399afd56d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install git+https://github.com/araffin/rlss23-dqn-tutorial/ --upgrade"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1d41f978-b6dd-47aa-baa6-04514109cd0c",
   "metadata": {},
   "source": [
    "## DQN Components"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "e2369b55-266c-4dbe-b14d-250ef7386407",
   "metadata": {},
   "source": [
    "## 1. Replay Buffer\n",
    "\n",
    "The replay buffer is one of the main component of DQN. It contains a collection of transitions, the same way we had a fixed dataset of transitions with FQI. However, compared to FQI, this ring buffer is consistently updated with new experience (and old transitions are dropped when the max capicity is reached).\n",
    "\n",
    "<div>\n",
    "    <img src=\"https://araffin.github.io/slides/dqn-tutorial/images/dqn/replay_buffer.png\" width=\"800\"/>\n",
    "</div>\n",
    "\n",
    "To update the Q-Network, DQN samples mini-batches from the replay buffer (vs the whole dataset for FQI).\n",
    "\n",
    "Each mini-batch can be represented using this structure:\n",
    "\n",
    "```python\n",
    "@dataclass\n",
    "class ReplayBufferSamples:\n",
    "    \"\"\"\n",
    "    A dataclass containing transitions from the replay buffer.\n",
    "    \"\"\"\n",
    "\n",
    "    observations: np.ndarray  # same as states in the theory\n",
    "    next_observations: np.ndarray\n",
    "    actions: np.ndarray\n",
    "    rewards: np.ndarray\n",
    "    terminateds: np.ndarray\n",
    "```\n",
    "\n",
    "<div>\n",
    "    <img src=\"https://araffin.github.io/slides/dqn-tutorial/images/dqn/replay_buffer_sampling.png\" width=\"600\"/>\n",
    "</div>\n"
   ]
  },
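  {
   "cell_type": "markdown",
   "id": "5b1e7c42-ring-buffer-sketch",
   "metadata": {},
   "source": [
    "The overwrite behaviour of a ring buffer can be sketched on a toy example. This is only an illustration, not the exercise solution: the names `capacity`, `storage` and `write_idx` are made up for this sketch, and a plain integer array stands in for the transition buffers.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "capacity = 3  # hypothetical, tiny on purpose\n",
    "storage = np.zeros(capacity, dtype=np.int64)\n",
    "write_idx = 0  # next slot to write to\n",
    "is_full = False\n",
    "\n",
    "for value in [10, 11, 12, 13, 14]:\n",
    "    storage[write_idx] = value\n",
    "    write_idx += 1\n",
    "    # Wrap around when reaching the end: the oldest entries get overwritten\n",
    "    if write_idx == capacity:\n",
    "        write_idx = 0\n",
    "        is_full = True\n",
    "\n",
    "print(storage)  # [13 14 12]\n",
    "print(write_idx, is_full)  # 2 True\n",
    "```\n",
    "\n",
    "After five writes into a buffer of capacity three, the two oldest values (10 and 11) have been overwritten and the write index points at the next slot to replace."
   ]
  },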
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34722639-6c01-4c5b-ab65-c8d5cb2adaad",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Optional\n",
    "\n",
    "import numpy as np\n",
    "import torch as th\n",
    "from gymnasium import spaces\n",
    "\n",
    "from dqn_tutorial.dqn.replay_buffer import ReplayBufferSamples"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d1b07ed-cde5-41ac-adfc-e1733dfb21bd",
   "metadata": {},
   "source": [
    "### Exercise (15 minutes): write the replay buffer\n",
    "\n",
    "**HINT**: The replay buffer is similar to what you did in the `collect_data()` function in the FQI notebook ;)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aafc38e2-db83-4cd9-aa26-57a04896023f",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ReplayBuffer:\n",
    "    \"\"\"\n",
    "    A simple replay buffer class to store and sample transitions.\n",
    "\n",
    "    :param buffer_size: Max number of transitions to store\n",
    "    :param observation_space: Observation space of the env,\n",
    "        contains information about the observation type and shape.\n",
    "    :param action_space: Action space of the env,\n",
    "        contains information about the number of actions.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        buffer_size: int,\n",
    "        observation_space: spaces.Box,\n",
    "        action_space: spaces.Discrete,\n",
    "    ) -> None:\n",
    "        # Current position in the ring buffer\n",
    "        self.current_idx = 0\n",
    "        # Replay buffer max capacity\n",
    "        self.buffer_size = buffer_size\n",
    "        # Boolean flag to know when the buffer has reached its maximal capacity\n",
    "        self.is_full = False\n",
    "\n",
    "        self.observation_space = observation_space\n",
    "        self.action_space = action_space\n",
    "        # Create the different buffers\n",
    "        self.observations = np.zeros((buffer_size, *observation_space.shape), dtype=observation_space.dtype)\n",
    "        self.next_observations = np.zeros((buffer_size, *observation_space.shape), dtype=observation_space.dtype)\n",
    "        # The action is an integer\n",
    "        action_dim = 1\n",
    "        self.actions = np.zeros((buffer_size, action_dim), dtype=action_space.dtype)\n",
    "\n",
    "        ### YOUR CODE HERE\n",
    "        \n",
    "        # TODO: create the buffers (numpy arrays) for the rewards (dtype=np.float32)\n",
    "        # and the terminated signals (dtype=bool)\n",
    "        self.rewards = ...\n",
    "        self.terminateds = ...\n",
    "        \n",
    "        ### END OF YOUR CODE\n",
    "\n",
    "    def store_transition(\n",
    "        self,\n",
    "        obs: np.ndarray,\n",
    "        next_obs: np.ndarray,\n",
    "        action: int,\n",
    "        reward: float,\n",
    "        terminated: bool,\n",
    "    ) -> None:\n",
    "        \"\"\"\n",
    "        Store one transition in the buffer.\n",
    "\n",
    "        :param obs: Current observation\n",
    "        :param next_obs: Next observation\n",
    "        :param action: Action taken for the current observation\n",
    "        :param reward: Reward received after taking the action\n",
    "        :param terminated: Whether it is the end of an episode or not\n",
    "            (discarding episode truncation like timeout)\n",
    "        \"\"\"\n",
    "        ### YOUR CODE HERE\n",
    "\n",
    "        # TODO:\n",
    "        # 1. Update the different buffers defined in the __init__\n",
    "        # 2. Update the pointer (`self.current_idx`). Be careful,\n",
    "        #  the pointer needs to be set to zero when reaching the end of the ring buffer.\n",
    "\n",
    "        # Update the buffers to store the new transition\n",
    "\n",
    "        \n",
    "        # Update the pointer\n",
    "\n",
    "        # If the buffer is full, we start from zero again, this is a ring buffer\n",
    "        # you also need to set the flag `is_full` to True (so we know the buffer has reached its max capacity)\n",
    "\n",
    "\n",
    "        ### END OF YOUR CODE\n",
    "\n",
    "    def sample(self, batch_size: int) -> ReplayBufferSamples:\n",
    "        \"\"\"\n",
    "        Sample with replacement `batch_size` transitions from the buffer.\n",
    "\n",
    "        :param batch_size: How many transitions to sample.\n",
    "        :return: Samples from the replay buffer\n",
    "        \"\"\"\n",
    "        ### YOUR CODE HERE\n",
    "\n",
    "        # TODO:\n",
    "        # 1. Retrieve the upper bound (max index that can be sampled)\n",
    "        #  it corresponds to `self.buffer_size` when the ring buffer is full (we can samples all indices)\n",
    "        # 2. Sample `batch_size` indices with replacement from the buffer\n",
    "        # (in the range [0, upper_bound[ ), numpy has a method `np.random.randint` for that ;)\n",
    "        upper_bound = ...\n",
    "        batch_indices = ...\n",
    "        \n",
    "        ### END OF YOUR CODE\n",
    "\n",
    "        return ReplayBufferSamples(\n",
    "            self.observations[batch_indices],\n",
    "            self.next_observations[batch_indices],\n",
    "            self.actions[batch_indices],\n",
    "            self.rewards[batch_indices],\n",
    "            self.terminateds[batch_indices],\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d7f4c768-2c3b-4eac-b025-5ed04eaa241c",
   "metadata": {},
   "source": [
    "testing the replay buffer, without reaching max capacity:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60ec422f-2edd-40df-8da2-c1093e1da38c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gymnasium as gym\n",
    "\n",
    "env = gym.make(\"CartPole-v1\")\n",
    "\n",
    "buffer_size = 1000\n",
    "buffer = ReplayBuffer(buffer_size, env.observation_space, env.action_space)\n",
    "obs, _ = env.reset()\n",
    "# Fill the buffer\n",
    "for _ in range(500):\n",
    "    action = int(env.action_space.sample())\n",
    "    next_obs, reward, terminated, truncated, _ = env.step(action)\n",
    "    # Store new transition in the replay buffer\n",
    "    buffer.store_transition(obs, next_obs, action, float(reward), terminated)\n",
    "    # Update current observation\n",
    "    obs = next_obs\n",
    "\n",
    "    done = terminated or truncated\n",
    "    if done:\n",
    "        obs, _ = env.reset()\n",
    "\n",
    "assert not buffer.is_full\n",
    "assert buffer.current_idx == 500\n",
    "samples = buffer.sample(batch_size=10)\n",
    "assert len(samples.observations) == 10\n",
    "assert samples.actions.shape == (10, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc01ab42-cc8f-4ae3-90de-3cd99de3db31",
   "metadata": {},
   "source": [
    "Testing the replay buffer when reaching max capacity:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94e9196c-bf2a-4433-bc36-3100a145a318",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fill the buffer completely\n",
    "for _ in range(1000):\n",
    "    action = int(env.action_space.sample())\n",
    "    next_obs, reward, terminated, truncated, _ = env.step(action)\n",
    "    buffer.store_transition(obs, next_obs, action, float(reward), terminated)\n",
    "    # Update current observation\n",
    "    obs = next_obs\n",
    "\n",
    "    done = terminated or truncated\n",
    "    if done:\n",
    "        obs, _ = env.reset()\n",
    "\n",
    "assert buffer.is_full\n",
    "# We did a full loop\n",
    "assert buffer.current_idx == 500\n",
    "# Check sampling with replacement\n",
    "# we should be able to sample more transitions\n",
    "# than the capacity of the ring buffer\n",
    "samples = buffer.sample(batch_size=1001)\n",
    "assert len(samples.observations) == 1001\n",
    "assert samples.actions.shape == (1001, 1)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "34619317-64d8-4872-a25e-750caa0f1588",
   "metadata": {},
   "source": [
    "## 2. Q-Network\n",
    "\n",
    "<div>\n",
    "    <img src=\"https://araffin.github.io/slides/dqn-tutorial/images/dqn/q_network.png\" width=\"600\"/>\n",
    "</div>\n",
    "\n",
    "Similiar to FQI, the Q-Network estimates $Q_\\pi(s_t, a_t)$, however, instead of taking the action as input, it outputs q-values for all possible actions as output."
   ]
  },
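  {
   "cell_type": "markdown",
   "id": "8c2d4f19-q-values-shape-sketch",
   "metadata": {},
   "source": [
    "To make the output shape concrete, here is a small sketch with made-up numbers. A NumPy array stands in for the network output (the real Q-Network returns a PyTorch tensor), and the values and the number of actions are assumptions chosen for the example.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Made-up q-values for a batch of 2 observations and 3 discrete actions,\n",
    "# i.e. an array of shape (batch_size, n_actions)\n",
    "q_values = np.array([[0.1, 0.5, -0.2], [1.0, 0.3, 0.7]])\n",
    "\n",
    "# Greedy action for each observation: index of the highest q-value in each row\n",
    "greedy_actions = q_values.argmax(axis=1)\n",
    "print(greedy_actions)  # [1 0]\n",
    "```\n",
    "\n",
    "One forward pass therefore gives the values of all actions at once, so picking the best action does not require one network evaluation per action."
   ]
  },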
  {
   "cell_type": "markdown",
   "id": "4a0de4e1-c605-4b8e-bc79-585de22dd4e6",
   "metadata": {},
   "source": [
    "### Exercise (6 minutes): write the Q-Network using PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7210df0c-4e0c-4dc0-a893-b5a1ccb815d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Type\n",
    "\n",
    "import torch as th\n",
    "import torch.nn as nn\n",
    "from gymnasium import spaces\n",
    "\n",
    "\n",
    "class QNetwork(nn.Module):\n",
    "    \"\"\"\n",
    "    A Q-Network for the DQN algorithm\n",
    "    to estimate the q-value for a given observation.\n",
    "\n",
    "    :param observation_space: Observation space of the env,\n",
    "        contains information about the observation type and shape.\n",
    "    :param action_space: Action space of the env,\n",
    "        contains information about the number of actions.\n",
    "    :param n_hidden_units: Number of units for each hidden layer.\n",
    "    :param activation_fn: Activation function (ReLU by default)\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        observation_space: spaces.Box,\n",
    "        action_space: spaces.Discrete,\n",
    "        n_hidden_units: int = 64,\n",
    "        activation_fn: Type[nn.Module] = nn.ReLU,\n",
    "    ) -> None:\n",
    "        super().__init__()\n",
    "        # Assume 1d space\n",
    "        obs_dim = observation_space.shape[0]\n",
    "\n",
    "        ### YOUR CODE HERE\n",
    "        # TODO:\n",
    "        # 1. Retrieve the number of discrete actions,\n",
    "        # that will be the number of ouputs of the q-network\n",
    "        # 2. Create the q-network, it will be a two layers fully-connected\n",
    "        # neural network which take the state (observation) as input\n",
    "        # and outputs the q-values for all possible actions\n",
    "\n",
    "        # Retrieve the number of discrete actions (using attribute `n` from `action_space`)\n",
    "\n",
    "        \n",
    "        # Create the q network: a 2 fully connected hidden layers with `n_hidden_units` each\n",
    "        # with `activation_fn` for the activation function after each hidden layer.\n",
    "        # You should use `nn.Sequential` (combine several layers to create a network)\n",
    "        # `nn.Linear` (fully connected layer) from PyTorch.\n",
    "        self.q_net = ...\n",
    "\n",
    "        ### END OF YOUR CODE\n",
    "\n",
    "    def forward(self, observations: th.Tensor) -> th.Tensor:\n",
    "        \"\"\"\n",
    "        :param observations: A batch of observation (batch_size, obs_dim)\n",
    "        :return: The Q-values for the given observations\n",
    "            for all the action (batch_size, n_actions)\n",
    "        \"\"\"\n",
    "        return self.q_net(observations)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d2d22ad-4a38-49ab-849d-276b233c972a",
   "metadata": {},
   "source": [
    "To test your Q-Network:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6384a47-4f34-4ae6-96c7-38107d097e50",
   "metadata": {},
   "outputs": [],
   "source": [
    "env = gym.make(\"CartPole-v1\")\n",
    "q_net = QNetwork(env.observation_space, env.action_space)\n",
    "\n",
    "print(q_net)\n",
    "assert isinstance(q_net.q_net, nn.Sequential)\n",
    "assert len(q_net.q_net) == 5\n",
    "assert isinstance(q_net.q_net[0], nn.Linear)\n",
    "assert isinstance(q_net.q_net[1], nn.ReLU)\n",
    "assert isinstance(q_net.q_net[-1], nn.Linear)\n",
    "\n",
    "obs, _ = env.reset()\n",
    "\n",
    "with th.no_grad():\n",
    "    # Create the input to the Q-Network\n",
    "    obs_tensor = th.as_tensor(obs[np.newaxis, ...])\n",
    "    q_values = q_net(obs_tensor)\n",
    "\n",
    "    assert q_values.shape == (1, 2)\n",
    "\n",
    "    best_action = q_values.argmax().item()\n",
    "\n",
    "    assert isinstance(best_action, int)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61ba21ac-b2c3-4d23-bf5a-6ea1ffd62ec7",
   "metadata": {},
   "source": [
    "## 3. Epsilon-greedy data collection"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "4a3fae51-6330-495c-96dd-46616d3c71ed",
   "metadata": {},
   "source": [
    "### Exercise (6 minutes): write the epsilon-greedy action selection strategy\n",
    "\n",
    "<div>\n",
    "    <img src=\"https://araffin.github.io/slides/dqn-tutorial/images/dqn/epsilon_greedy.png\" width=\"600\"/>\n",
    "</div>"
   ]
  },
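  {
   "cell_type": "markdown",
   "id": "3f9a6b07-biased-coin-sketch",
   "metadata": {},
   "source": [
    "The 'biased coin' behind epsilon-greedy can be checked empirically. The sketch below only covers the coin toss, not the action selection itself; the seed, the value of `exploration_rate` and the number of tosses are arbitrary choices for this example.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(seed=0)\n",
    "exploration_rate = 0.1  # epsilon, value chosen for the example\n",
    "n_tosses = 100_000\n",
    "\n",
    "# One biased coin toss per step: True means 'take a random action'\n",
    "explore = rng.random(n_tosses) < exploration_rate\n",
    "\n",
    "# The fraction of exploratory steps should be close to epsilon\n",
    "print(explore.mean())\n",
    "```\n",
    "\n",
    "Comparing a uniform sample in [0, 1) against epsilon yields a random action with probability epsilon and a greedy action otherwise."
   ]
  },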
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f66aafe-b4eb-4878-af9a-b83b29e0d78a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def epsilon_greedy_action_selection(\n",
    "    q_net: QNetwork,\n",
    "    observation: np.ndarray,\n",
    "    exploration_rate: float,\n",
    "    action_space: spaces.Discrete,\n",
    "    device: str = \"cpu\",\n",
    ") -> int:\n",
    "    \"\"\"\n",
    "    Select an action according to an espilon-greedy policy:\n",
    "    with a probability of epsilon (`exploration_rate`),\n",
    "    sample a random action, otherwise follow the best known action\n",
    "    according to the q-value.\n",
    "\n",
    "    :param observation: A single observation.\n",
    "    :param q_net: Q-network for estimating the q value\n",
    "    :param exploration_rate: Current rate of exploration (in [0, 1], 0 means no exploration),\n",
    "        probability to select a random action,\n",
    "        this is \"epsilon\".\n",
    "    :param action_space: Action space of the env,\n",
    "        contains information about the number of actions.\n",
    "    :param device: PyTorch device\n",
    "    :return: An action selected according to the epsilon-greedy policy.\n",
    "    \"\"\"\n",
    "    ### YOUR CODE HERE\n",
    "    # TODO:\n",
    "    # 1. Toss a biased coin (you can use `np.random.rand()`)\n",
    "    # to decide if the agent should take a random action or not\n",
    "    # (with probability p=`exploration_rate`)\n",
    "    # 2. Either take a random action (sample the action space)\n",
    "    # or follow the greedy policy (take the action with the highest q-value)\n",
    "\n",
    "    if ...:\n",
    "        # Random action\n",
    "        action = ...\n",
    "    else:\n",
    "        # Greedy action\n",
    "        # We do not need to compute the gradient, so we use `with th.no_grad():`\n",
    "        with th.no_grad():\n",
    "            # Convert the input to PyTorch tensor and add a batch dimension (obs_dim,) -> (1, obs_dim)\n",
    "            # you can use `th.as_tensor` and numpy `[np.newaxis, ...]`\n",
    "\n",
    "            # Compute q values for all actions\n",
    "\n",
    "            # Greedy policy: select action with the highest q value\n",
    "            # you can use PyTorch `.argmax()` for that\n",
    "            action = ...\n",
    "            \n",
    "    ### END OF YOUR CODE\n",
    "\n",
    "    return action"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db4b5ef9-1a60-4af5-ae50-566b0776b62c",
   "metadata": {},
   "source": [
    "### Exercise (6 minutes): Collect one transition and store it in the replay buffer\n",
    "\n",
    "**HINT**: This is a simplified version of what you did when writing `collect_data()` in the FQI notebook. The main difference here is that you will be using an epsilon-greedy policy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c6823ed-8f88-48dc-a5a9-4e9d260603ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "def collect_one_step(\n",
    "    env: gym.Env,\n",
    "    q_net: QNetwork,\n",
    "    replay_buffer: ReplayBuffer,\n",
    "    obs: np.ndarray,\n",
    "    exploration_rate: float = 0.1,\n",
    "    verbose: int = 0,\n",
    ") -> np.ndarray:\n",
    "    \"\"\"\n",
    "    Collect one transition and fill the replay buffer following an epsilon greedy policy.\n",
    "\n",
    "    :param env: The environment object.\n",
    "    :param q_net: Q-network for estimating the q value\n",
    "    :param replay_buffer: Replay buffer to store the new transitions.\n",
    "    :param obs: The current observation.\n",
    "    :param exploration_rate: Current rate of exploration (in [0, 1], 0 means no exploration),\n",
    "        probability to select a random action,\n",
    "        this is \"epsilon\".\n",
    "    :param verbose: The verbosity level (1 to print some info).\n",
    "    :return: The last observation (important when collecting data multiple times).\n",
    "    \"\"\"\n",
    "    ### YOUR CODE HERE\n",
    "\n",
    "    # Select an action following an epsilon-greedy policy\n",
    "    # you should use `epsilon_greedy_action_selection()`\n",
    "\n",
    "    # Step in the env\n",
    "\n",
    "    # Store the transition in the replay buffer\n",
    "\n",
    "    # Update current observation\n",
    "\n",
    "    \n",
    "    ### END OF YOUR CODE\n",
    "\n",
    "    if \"episode\" in info and verbose >= 1:\n",
    "        print(f\"Episode return={float(info['episode']['r']):.2f} length={int(info['episode']['l'])}\")\n",
    "\n",
    "    done = terminated or truncated\n",
    "    if done:\n",
    "        # Don't forget to reset the env at the end of an episode\n",
    "        obs, _ = env.reset()\n",
    "    return obs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e72093a8-b227-40e5-a104-191b086e33c2",
   "metadata": {},
   "source": [
    "Let's test the data collection:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "045b251a-9ace-450b-afa9-7ee4128cdef6",
   "metadata": {},
   "outputs": [],
   "source": [
    "env = gym.make(\"CartPole-v1\")\n",
    "q_net = QNetwork(env.observation_space, env.action_space)\n",
    "buffer = ReplayBuffer(2000, env.observation_space, env.action_space)\n",
    "\n",
    "obs, _ = env.reset()\n",
    "for _ in range(1000):\n",
    "    obs = collect_one_step(env, q_net, buffer, obs, exploration_rate=0.1)\n",
    "    \n",
    "# Check current buffer position\n",
    "assert buffer.current_idx == 1000\n",
    "\n",
    "# Collect more data\n",
    "for _ in range(1000):\n",
    "    obs = collect_one_step(env, q_net, buffer, obs, exploration_rate=0.1)\n",
    "    \n",
    "# Buffer is full\n",
    "assert buffer.current_idx == 0\n",
    "assert buffer.is_full"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "fb9bf659-e0e3-42f7-b774-1c989bfc6c8f",
   "metadata": {},
   "source": [
    "### Bonus: Exploration Schedule\n",
    "\n",
    "<div>\n",
    "    <img src=\"https://araffin.github.io/slides/dqn-tutorial/images/dqn/linear_schedule.png\" width=\"600\"/>\n",
    "</div>\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "251042a1-0250-4559-8f75-db5c7fd67d2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def linear_schedule(initial_value: float, final_value: float, current_step: int, max_steps: int) -> float:\n",
    "    \"\"\"\n",
    "    Linear schedule for the exploration rate (epsilon).\n",
    "    Note: we clip the value so the schedule is constant after reaching the final value\n",
    "    at `max_steps`.\n",
    "\n",
    "    :param initial_value: Initial value of the schedule.\n",
    "    :param final_value: Final value of the schedule.\n",
    "    :param current_step: Current step of the schedule.\n",
    "    :param max_steps: Maximum number of steps of the schedule.\n",
    "    :return: The current value of the schedule.\n",
    "    \"\"\"\n",
    "    ### YOUR CODE HERE\n",
    "    \n",
    "    # Compute current progress (in [0, 1], 0 being the start)\n",
    "\n",
    "    # Clip the progress so the schedule is constant after reaching the final value\n",
    "    current_value = ...\n",
    "    \n",
    "    ### END OF YOUR CODE\n",
    "    \n",
    "    return current_value"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59df7864-f45f-4503-9d21-b07dc22feece",
   "metadata": {},
   "source": [
    "To test the linear schedule:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6f95256-6a76-4323-849e-886741da0653",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Linear schedule\n",
    "exploration_initial_eps = 1.0\n",
    "exploration_final_eps = 0.01\n",
    "exploration_rate = exploration_initial_eps\n",
    "n_steps = 100\n",
    "for step in range(n_steps + 1):\n",
    "    exploration_rate = linear_schedule(exploration_initial_eps, exploration_final_eps, step, n_steps)\n",
    "    if step == 0:\n",
    "        assert exploration_rate == exploration_initial_eps\n",
    "\n",
    "    obs = collect_one_step(env, q_net, buffer, obs, exploration_rate=exploration_rate)\n",
    "\n",
    "assert np.allclose(exploration_rate, exploration_final_eps)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "69a60230-6477-4318-8d60-10f6eada6064",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "In this notebook, you have seen how to implement the main components of DQN:\n",
    "- how to create a replay buffer\n",
    "- how to sample actions according to an espilon-greedy exploration strategy\n",
    "- how to create a Q-Network\n",
    "\n",
    "\n",
    "In part II, we will put everything together to finalize DQN implementation.\n",
    "We will see how to implement the DQN update using gradient descent and how to implement the training loop that alternates between data collection and updating the Q-Network. "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
